package com.njupt.wuaiagent.rag;

import org.springframework.ai.chat.client.ChatClient;
import org.springframework.ai.chat.model.ChatModel;
import org.springframework.ai.rag.Query;
import org.springframework.ai.rag.preretrieval.query.transformation.QueryTransformer;
import org.springframework.ai.rag.preretrieval.query.transformation.RewriteQueryTransformer;
import org.springframework.stereotype.Component;

/**
 * @Author: wujiaming
 * @CreateTime: 2025/5/17 20:39
 * @Description: Query rewriting — wraps Spring AI's {@link RewriteQueryTransformer} so callers
 *               can rewrite a raw user prompt with a single method call.
 * @Version: 1.0
 */
@Component
public class QueryRewriter {

    /** Rewrite transformer built once at construction time and reused for every call. */
    private final QueryTransformer queryTransformer;

    /**
     * Creates the rewriter on top of the injected chat model.
     *
     * <p>NOTE(review): keep the parameter name {@code dashscopeChatModel}. If several
     * {@link ChatModel} beans exist, Spring may pick the right one by parameter name.
     *
     * @param dashscopeChatModel chat model the rewrite transformer sends its prompt to
     */
    public QueryRewriter(ChatModel dashscopeChatModel) {
        // Create the chat-client builder first, then hand it to the rewrite transformer.
        final ChatClient.Builder clientBuilder = ChatClient.builder(dashscopeChatModel);
        this.queryTransformer = RewriteQueryTransformer.builder()
                .chatClientBuilder(clientBuilder)
                .build();
    }

    /**
     * Rewrites a user prompt using the configured {@link RewriteQueryTransformer}.
     *
     * @param prompt the original user prompt; it is wrapped in a {@link Query} as-is
     * @return the text of the rewritten query
     */
    public String doQueryRewriter(String prompt) {
        final Query original = new Query(prompt);
        final Query rewritten = queryTransformer.transform(original);
        return rewritten.text();
    }
}
